{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2021 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Accelerate BERT encoder with TF-TRT\n",
    "\n",
    "\n",
    "## Introduction\n",
    "\n",
    "NVIDIA TensorRT is a C++ library that facilitates high-performance inference on NVIDIA graphics processing units (GPUs). TensorFlow™ integration with TensorRT™ (TF-TRT) optimizes the TensorRT-compatible parts of your computation graph and lets TensorFlow execute the rest. You keep TensorFlow's wide and flexible feature set, while TensorRT produces a highly optimized runtime engine for the TensorRT-compatible subgraphs of your network.\n",
    "\n",
    "In this notebook, we demonstrate accelerating BERT inference using TF-TRT. We focus on the encoder.\n",
    "\n",
    "## Requirements\n",
    "This notebook requires TensorFlow (TF) 2.5 or later and TensorRT (TRT) 7.1.3 or later."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Download the model\n",
    "We download the BERT base model (uncased, 12 layers, hidden size 768, 12 attention heads) from [TF-Hub](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q tf-models-official"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Enabling eager execution\n",
      "INFO:tensorflow:Enabling v2 tensorshape\n",
      "INFO:tensorflow:Enabling resource variables\n",
      "INFO:tensorflow:Enabling tensor equality\n",
      "INFO:tensorflow:Enabling control flow v2\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "tfhub_handle_encoder = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3'\n",
    "bert_saved_model_path = 'bert_base'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 910). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: bert_base/assets\n"
     ]
    }
   ],
   "source": [
    "bert_model = hub.load(tfhub_handle_encoder)\n",
    "tf.saved_model.save(bert_model, bert_saved_model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Inference\n",
    "In this section we convert the model with TF-TRT, run inference, and compare its latency with that of the original model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from tensorflow.python.saved_model import signature_constants\n",
    "from tensorflow.python.saved_model import tag_constants\n",
    "from tensorflow.python.compiler.tensorrt import trt_convert as trt\n",
    "from timeit import default_timer as timer\n",
    "\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_func_from_saved_model(saved_model_dir):\n",
    "    saved_model_loaded = tf.saved_model.load(\n",
    "        saved_model_dir, tags=[tag_constants.SERVING])\n",
    "    graph_func = saved_model_loaded.signatures[\n",
    "        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n",
    "    return graph_func, saved_model_loaded"
   ]
  },
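  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The benchmark helper in the next cell times `N_run` calls after the warm-up runs and reports the mean latency and the throughput. Written out, using notation introduced here only for illustration ($t_i$ is the wall-clock time of timed call $i$ in seconds, $N$ the number of timed runs, and $B$ the batch size):\n",
    "\n",
    "$$\\text{latency} = \\frac{1000}{N} \\sum_{i=1}^{N} t_i \\;\\text{ms}, \\qquad \\text{throughput} = \\frac{N \\cdot B}{\\sum_{i=1}^{N} t_i} \\;\\text{samples/s}$$\n",
    "\n",
    "For batch size 1 the throughput is the reciprocal of the mean latency; for example, a mean latency of 4.54 ms corresponds to about 220 samples/s."
   ]
  },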
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_and_benchmark_throughput(input_dict, model, N_warmup_run=50, N_run=500,\n",
    "                                     result_key='predictions', batch_size=None):\n",
    "    # The batch size is the leading dimension of any input tensor.\n",
    "    input_batch_size = next(iter(input_dict.values())).shape[0]\n",
    "    if batch_size is None or batch_size > input_batch_size:\n",
    "        batch_size = input_batch_size\n",
    "\n",
    "    print('Benchmarking with batch size', batch_size)\n",
    "\n",
    "    # Warm-up runs are excluded from the timing.\n",
    "    for i in range(N_warmup_run):\n",
    "        preds = model(**input_dict)\n",
    "\n",
    "    # Force device synchronization with .numpy()\n",
    "    tmp = preds[result_key][0].numpy()\n",
    "\n",
    "    elapsed_time = np.zeros(N_run)\n",
    "    for i in range(N_run):\n",
    "        start_time = timer()\n",
    "        preds = model(**input_dict)\n",
    "        # Synchronize so that the timer covers the whole inference call\n",
    "        tmp += preds[result_key][0].numpy()\n",
    "        end_time = timer()\n",
    "        elapsed_time[i] = end_time - start_time\n",
    "\n",
    "        # Report the average over the previous 50 steps\n",
    "        if i >= 50 and i % 50 == 0:\n",
    "            print('Steps {}-{} average: {:4.1f}ms'.format(i-50, i, (elapsed_time[i-50:i].mean()) * 1000))\n",
    "\n",
    "    latency = elapsed_time.mean() * 1000\n",
    "    print('Latency: {:5.2f}+/-{:4.2f}ms'.format(latency, elapsed_time.std() * 1000))\n",
    "    print('Throughput: {:.0f} samples/s'.format(N_run * batch_size / elapsed_time.sum()))\n",
    "    return latency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trt_convert(input_path, output_path, input_shapes, explicit_batch=False,\n",
    "                dtype=np.float32, precision='FP32', prof_strategy='Optimal'):\n",
    "    conv_params = trt.TrtConversionParams(\n",
    "        precision_mode=precision, minimum_segment_size=50,\n",
    "        max_workspace_size_bytes=12 << 30,  # 12 GiB\n",
    "        maximum_cached_engines=1)\n",
    "    # explicit_batch=True enables TF-TRT dynamic shape mode.\n",
    "    converter = trt.TrtGraphConverterV2(\n",
    "        input_saved_model_dir=input_path, conversion_params=conv_params,\n",
    "        use_dynamic_shape=explicit_batch,\n",
    "        dynamic_shape_profile_strategy=prof_strategy)\n",
    "\n",
    "    converter.convert()\n",
    "\n",
    "    def input_fn():\n",
    "        # Yield one list of input arrays per set of shapes to build for.\n",
    "        for shapes in input_shapes:\n",
    "            yield [np.ones(shape=x).astype(dtype) for x in shapes]\n",
    "\n",
    "    # Build the TRT engines ahead of time for the shapes produced by input_fn.\n",
    "    converter.build(input_fn)\n",
    "    converter.save(output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_input(batch_size, seq_length):\n",
    "    # Random word IDs; the mask is all ones and the type IDs are all zeros.\n",
    "    mask = tf.convert_to_tensor(np.ones((batch_size, seq_length), dtype=np.int32))\n",
    "    type_id = tf.convert_to_tensor(np.zeros((batch_size, seq_length), dtype=np.int32))\n",
    "    word_id = tf.convert_to_tensor(np.random.randint(0, 1000, size=[batch_size, seq_length], dtype=np.int32))\n",
    "    return {'input_mask': mask, 'input_type_ids': type_id, 'input_word_ids': word_id}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Convert the model with TF-TRT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 910). These functions will not be directly callable after loading.\n"
     ]
    }
   ],
   "source": [
    "bert_trt_path = bert_saved_model_path + '_trt'\n",
    "input_shapes = [[(1, 128), (1, 128), (1, 128)]] \n",
    "trt_convert(bert_saved_model_path, bert_trt_path, input_shapes, True, np.int32, precision='FP16')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Run inference with converted model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "trt_func, _ = get_func_from_saved_model(bert_trt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Benchmarking with batch size 1\n",
      "Steps 0-50 average:  4.6ms\n",
      "Steps 50-100 average:  4.6ms\n",
      "Steps 100-150 average:  4.6ms\n",
      "Steps 150-200 average:  4.6ms\n",
      "Steps 200-250 average:  4.5ms\n",
      "Steps 250-300 average:  4.5ms\n",
      "Steps 300-350 average:  4.5ms\n",
      "Steps 350-400 average:  4.5ms\n",
      "Steps 400-450 average:  4.5ms\n",
      "Latency:  4.54+/-0.24ms\n",
      "Throughput: 220 samples/s\n"
     ]
    }
   ],
   "source": [
    "input_dict = random_input(1, 128)\n",
    "result_key = 'bert_encoder_1' # 'classifier'\n",
    "res = predict_and_benchmark_throughput(input_dict, trt_func, result_key=result_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4 Compare to the original model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Benchmarking with batch size 1\n",
      "Steps 0-50 average:  8.5ms\n",
      "Steps 50-100 average:  9.0ms\n",
      "Steps 100-150 average:  8.5ms\n",
      "Steps 150-200 average:  8.6ms\n",
      "Steps 200-250 average:  8.7ms\n",
      "Steps 250-300 average: 10.1ms\n",
      "Steps 300-350 average:  8.6ms\n",
      "Steps 350-400 average:  9.2ms\n",
      "Steps 400-450 average:  8.5ms\n",
      "Latency:  8.84+/-0.86ms\n",
      "Throughput: 113 samples/s\n"
     ]
    }
   ],
   "source": [
    "func, model = get_func_from_saved_model(bert_saved_model_path)\n",
    "res = predict_and_benchmark_throughput(input_dict, func, result_key=result_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Dynamic sequence length\n",
    "The encoder accepts a dynamic sequence length, so we can pass inputs of different lengths. Here we call the original model on two sequences, one of length 128 and one of length 180."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq1 = random_input(1, 128)\n",
    "res1 = func(**seq1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq2 = random_input(1, 180)\n",
    "res2 = func(**seq2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The converted model is optimized for a sequence length of 128 (and batch size 1). If we run the converted model with a different sequence length, one of two things happens:\n",
    "1. If `TrtConversionParams.allow_build_at_runtime` is `False`, the native TF model is used for inference.\n",
    "2. If `TrtConversionParams.allow_build_at_runtime` is `True`, a new TRT engine is built and optimized for the new sequence length.\n",
    "\n",
    "The first option does not provide TRT acceleration, while the second incurs a large overhead while the new engine is built. In the next section we convert the model to handle multiple sequence lengths.\n",
    "\n",
    "### 3.1 TRT Conversion with dynamic sequence length"
   ]
  },
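  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough mental model of the `'Range'` profile strategy used in the next cell: the TF-TRT documentation describes it as creating one optimization profile that covers dimension values between a minimum and a maximum derived from the input shapes provided at build time. The sketch below is plain Python, not TF-TRT code, and it only illustrates that idea; `range_profile` and `fits` are hypothetical helper names introduced here.\n",
    "\n",
    "```python\n",
    "def range_profile(input_shapes):\n",
    "    # Group the build-time shapes by input position, then take the\n",
    "    # per-dimension minimum and maximum for each input.\n",
    "    per_input = list(zip(*input_shapes))\n",
    "    mins = [tuple(min(d) for d in zip(*shapes)) for shapes in per_input]\n",
    "    maxs = [tuple(max(d) for d in zip(*shapes)) for shapes in per_input]\n",
    "    return mins, maxs\n",
    "\n",
    "\n",
    "def fits(shapes, profile):\n",
    "    # True if every dimension of every input lies within [min, max].\n",
    "    mins, maxs = profile\n",
    "    return all(lo <= d <= hi\n",
    "               for shape, mn, mx in zip(shapes, mins, maxs)\n",
    "               for d, lo, hi in zip(shape, mn, mx))\n",
    "\n",
    "\n",
    "input_shapes = [[(1, 128), (1, 128), (1, 128)], [(1, 180), (1, 180), (1, 180)]]\n",
    "profile = range_profile(input_shapes)\n",
    "print(fits([(1, 150)] * 3, profile))  # True: 150 lies between 128 and 180\n",
    "print(fits([(1, 256)] * 3, profile))  # False: 256 exceeds the maximum of 180\n",
    "```\n",
    "\n",
    "Under this model, the two shape sets passed to the conversion below would cover sequence lengths from 128 to 180 at batch size 1."
   ]
  },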
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 910). These functions will not be directly callable after loading.\n"
     ]
    }
   ],
   "source": [
    "bert_trt_path = bert_saved_model_path + '_trt2'\n",
    "input_shapes = [[(1, 128), (1, 128), (1, 128)], [(1, 180), (1, 180), (1, 180)]] \n",
    "trt_convert(bert_saved_model_path, bert_trt_path, input_shapes, True, np.int32, precision='FP16',\n",
    "            prof_strategy='Range')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "trt_func_dynamic, _ = get_func_from_saved_model(bert_trt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "trt_res = trt_func_dynamic(**seq1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_key = 'bert_encoder_1' # 'classifier'\n",
    "res = predict_and_benchmark_throughput(seq1, trt_func_dynamic, result_key=result_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = predict_and_benchmark_throughput(seq2, trt_func_dynamic, result_key=result_key)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
